import logging
import math

from typing import List, Optional

import torch
import torch.distributed._tensor.placement_types as placement_types
from torch.distributed._tensor.device_mesh import DeviceMesh, mesh_resources
from torch.distributed.distributed_c10d import (
    all_to_all,
    broadcast,
    get_global_rank,
    get_rank,
    get_world_size,
    GroupMember,
    ProcessGroup,
    scatter,
    Work,
)

logger = logging.getLogger(__name__)


# TODO: we need to migrate these APIs to be functional collectives


def mesh_scatter(
    output: torch.Tensor,
    scatter_list: List[torch.Tensor],
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
) -> Optional[Work]:
    """
    scatter a list of tensors to a device mesh dimension. We by default
    use the first rank of the mesh dimension as the source of truth, i.e
    for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will
    scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank
    2 to rank 2/3.

    Args:
        output (torch.Tensor): the tensor to receive the scattered list.
        scatter_list (List[torch.Tensor]): the tensor list to be scattered.
        mesh_dim (int, optional): indicate which mesh dimension we want
            to scatter on, we by default choose the first rank on the
            mesh dimension as source of truth.

    Returns:
        A :class:`Work` object
    """
    # TODO: Ideally we should use the meta tensor way
    # (to register a meta kernel for the collective op)
    # so that it would avoid the communication. Need to
    # remove the check below once that is done.
    if output.is_meta:
        return None
    dim_group = mesh.get_dim_groups(mesh_dim)
    assert isinstance(dim_group, ProcessGroup)
    # src need to be global rank
    src_for_dim = 0

    if dim_group is not GroupMember.WORLD:
        src_for_dim = get_global_rank(dim_group, 0)

    if src_for_dim == get_rank():
        fut = scatter(
            output,
            scatter_list=scatter_list,
            src=src_for_dim,
            group=dim_group,
            async_op=async_op,
        )
    else:
        fut = scatter(
            output,
            scatter_list=None,
            src=src_for_dim,
            group=dim_group,
            async_op=async_op,
        )

    return fut


def mesh_broadcast(
    tensor: torch.Tensor,
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
) -> Optional[Work]:
    """
    broadcast the tensor to a device mesh dimension. We by default
    use the first rank of the mesh dimension as the source of truth, i.e
    for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will
    broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2
    to rank 2/3.

    Args:
        tensor (torch.Tensor): tensor to broadcast.
        mesh_dim (int, optional): indicate which mesh dimension we want
            to scatter on, we by default choose the first rank on the
            mesh dimension as source of truth.

    Returns:
        A :class:`Work` object
    """
    # TODO: Ideally we should use the meta tensor way
    # (to register a meta kernel for the collective op)
    # so that it would avoid the communication. Need to
    # remove the check below once that is done.
    if tensor.is_meta:
        return None
    dim_group = mesh.get_dim_groups(mesh_dim)
    assert isinstance(dim_group, ProcessGroup)
    # src need to be global rank
    src_for_dim = 0
    if dim_group is not GroupMember.WORLD:
        src_for_dim = get_global_rank(dim_group, 0)

    return broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op)


# TODO: test uneven split on GLOO and NCCL
def mesh_all_to_all(
    output_tensor_list: List[torch.Tensor],
    input_tensor_list: List[torch.Tensor],
    mesh: DeviceMesh,
    mesh_dim: int = 0,
    async_op: bool = False,
) -> Optional[Work]:
    """
    All-to-all exchange of tensor lists across one dimension of a device mesh.

    On non-CPU meshes this defers to ``all_to_all``. On CPU meshes it falls
    back to one ``scatter`` per group member, with member ``i`` as the source
    and ``output_tensor_list[i]`` receiving its piece. The warning text
    suggests CPU here means the Gloo backend.

    Args:
        output_tensor_list (List[torch.Tensor]): receive buffers, one per
            group member.
        input_tensor_list (List[torch.Tensor]): tensors to send, one per
            group member.
        mesh (DeviceMesh): mesh whose dimension group is used.
        mesh_dim (int, optional): mesh dimension to exchange over. Default: 0.
        async_op (bool, optional): forwarded to each collective. Default: False.

    Returns:
        The value returned by ``all_to_all``. On the CPU fallback, the value
        returned by the *last* ``scatter`` only (``None`` if the group is
        empty).
        NOTE(review): with ``async_op=True`` on the CPU path, waiting on that
        single handle may not guarantee the earlier scatters have finished;
        confirm against the backend's ordering guarantees.
    """
    group = mesh.get_dim_groups(mesh_dim)
    assert isinstance(group, ProcessGroup)

    if mesh.device_type != "cpu":
        return all_to_all(
            output_tensor_list,
            input_tensor_list,
            group,
            async_op=async_op,
        )

    # no direct dist.all_to_all support on 'gloo' so we manually do scatters
    logger.warning(
        "ProcessGroupGloo does not support all_to_all, falling back with scatters!"
    )
    # TODO: pull the handle of uneven case in #492
    last_work = None
    for local_src in range(get_world_size(group)):
        # ``scatter`` expects src as a global rank.
        global_src = (
            local_src
            if group is GroupMember.WORLD
            else get_global_rank(group, local_src)
        )
        # Only the source rank supplies the list; everyone else passes [].
        last_work = scatter(
            output_tensor_list[local_src],
            input_tensor_list if mesh.get_rank() == global_src else [],
            group=group,
            src=global_src,
            async_op=async_op,
        )
    return last_work


def spec_to_bytes(spec: "placement_types.DTensorSpec") -> int:
    assert spec.tensor_meta is not None, "spec should have tensor meta defined!"
    return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape)


def get_bandwidth_factor(mesh: DeviceMesh) -> List[float]:
    """
    Return one bandwidth cost factor per mesh dimension.

    Walking from the last mesh dimension outward, the device count spanned so
    far is accumulated. A dimension gets the intra-host factor 0.2 while that
    running count still fits within ``num_devices_per_host``; otherwise it
    gets 1.0.
    """
    devices_per_host = mesh_resources.num_devices_per_host(mesh.device_type)

    # Built innermost-dimension first, reversed before returning.
    inner_to_outer: List[float] = []
    devices_spanned = 1
    for dim in range(mesh.ndim - 1, -1, -1):
        devices_spanned *= mesh.size(dim)
        # magic number for intra-host communication bandwidth factor
        # TODO: see if we need to tweak this or offer a way for user
        # to specify the bandwidths
        inner_to_outer.append(0.2 if devices_spanned <= devices_per_host else 1.0)

    inner_to_outer.reverse()
    return inner_to_outer


def allgather_cost(num_bytes: float, mesh: DeviceMesh, mesh_dim: int) -> float:
    """
    Estimate the cost of an allgather along ``mesh_dim``.

    Cost model: a constant latency term of 1, plus
    ``factor * num_bytes * (n - 1) / n``, where ``n`` is the size of
    ``mesh_dim`` and ``factor`` is that dimension's entry from
    ``get_bandwidth_factor``.
    """
    group_size = mesh.size(mesh_dim)
    factor = get_bandwidth_factor(mesh)[mesh_dim]
    bandwidth_term = factor * num_bytes * (group_size - 1) / group_size
    return 1 + bandwidth_term


def allreduce_cost(num_bytes: float, mesh: DeviceMesh, mesh_dim: int) -> float:
    """
    Estimate the cost of an allreduce along ``mesh_dim``.

    The model is the allgather/reduce_scatter model with the bandwidth term
    doubled: ``1 + 2 * factor * num_bytes * (n - 1) / n``, where ``n`` is the
    size of ``mesh_dim`` and ``factor`` comes from ``get_bandwidth_factor``.
    """
    group_size = mesh.size(mesh_dim)
    factor = get_bandwidth_factor(mesh)[mesh_dim]
    # allreduce moves twice the bytes of allgather / reduce_scatter
    traffic = 2 * factor * num_bytes * (group_size - 1) / group_size
    return 1 + traffic


def reduce_scatter_cost(
    num_bytes: float,
    mesh: DeviceMesh,
    mesh_dim: int,
) -> float:
    """
    Estimate the cost of a reduce_scatter along ``mesh_dim``.

    Uses the same model as allgather: ``1 + factor * num_bytes * (n - 1) / n``,
    where ``n`` is the size of ``mesh_dim`` and ``factor`` comes from
    ``get_bandwidth_factor``.
    """
    n_devices = mesh.size(mesh_dim)
    bw_factor = get_bandwidth_factor(mesh)[mesh_dim]
    latency = 1
    return latency + bw_factor * num_bytes * (n_devices - 1) / n_devices


def redistribute_cost(
    current_spec: "placement_types.DTensorSpec",
    target_spec: "placement_types.DTensorSpec",
) -> float:
    """
    Return the estimated cost of redistributing from ``current_spec`` to
    ``target_spec``.

    Mesh dimensions are compared pairwise. Each placement change adds the cost
    of the collective it implies:

    * Shard -> Replicate: allgather
    * Shard -> Shard: approximated as allgather plus a constant penalty of 1.0,
      since alltoall is not implemented yet (favors allgather instead)
    * Partial -> Replicate: allreduce
    * Partial -> Shard: reduce_scatter
    * Shard -> Partial: disallowed, returns ``float("inf")``

    Any other change (e.g. from Replicate) adds nothing to the cost.

    NOTE:
    1. Only communication cost is considered, since computation costs for
       redistribute are quite trivial (i.e. we only need to narrow or do a
       simple division).
    2. Only redistribution on the same mesh is considered; cross-mesh
       communication cost is not needed for operator strategy
       estimation/selection, so differing meshes cost ``float("inf")``.

    Returns:
        The accumulated cost, ``0.0`` when ``current_spec`` is fully
        replicated, or ``float("inf")`` for different meshes or a
        Shard -> Partial transition.
    """
    if current_spec.mesh != target_spec.mesh:
        # make infinite cost if meshes are not same
        # TODO: see if we want to support this once there's cross mesh communication
        return float("inf")

    if current_spec.is_replicated():
        # short-cut: comm cost is 0 if current spec is already full replication
        return 0.0

    mesh = current_spec.mesh
    cost = 0.0
    # bytes of one shard under the current placement
    comm_bytes = spec_to_bytes(current_spec) / current_spec.num_shards
    for i, (current, target) in enumerate(
        zip(current_spec.placements, target_spec.placements)
    ):
        if current == target:
            continue
        if current.is_shard() and target.is_replicate():
            # allgather gives larger comm bytes
            comm_bytes *= mesh.size(i)
            cost += allgather_cost(comm_bytes, mesh, i)
        elif current.is_shard() and target.is_shard():
            # should be alltoall comm, since we haven't implemented it yet, add
            # a penalty to favor allgather instead
            cost += allgather_cost(comm_bytes, mesh, i) + 1.0
        elif current.is_partial() and target.is_replicate():
            cost += allreduce_cost(comm_bytes, mesh, i)
        elif current.is_partial() and target.is_shard():
            cost += reduce_scatter_cost(comm_bytes, mesh, i)
            # after reduce_scatter, each later collective moves
            # 1 / mesh.size(i) of the bytes (only "halved" when size is 2)
            comm_bytes /= mesh.size(i)
        elif current.is_shard() and target.is_partial():
            # ban shard -> partial as it does not make sense to perform
            # this redistribute
            return float("inf")

    return cost
